import numpy as np
## How the mask matrix for UniLM-style BERT seq2seq is constructed
# Two example rows of 0/1 segment ids (batch of 2, length 6).
a = np.array(
    [
        [0, 0, 0, 1, 1, 1],
        [0, 0, 1, 1, 1, 1],
    ]
)

# Running sum along each row: stays 0 across the leading zeros, then
# increases by one at every position whose segment id is 1.
idxs = a.cumsum(axis=1)

# Two broadcastable views of the running sums:
#   key_view   has shape (batch, 1, length) -> varies along the last axis
#   query_view has shape (batch, length, 1) -> varies along the middle axis
key_view = idxs[:, np.newaxis, :]
query_view = idxs[..., np.newaxis]

# Show each intermediate array, followed by a separator line.
for view in (idxs, key_view, query_view):
    print(view)
    print("####")

# mask[b, i, j] is True when idxs[b, j] <= idxs[b, i]. Every row therefore
# sees all positions of the zero-valued prefix, while a position in the
# one-valued part sees only positions up to and including itself.
mask = query_view >= key_view

# Insert an extra axis after the batch axis -> shape (batch, 1, length, length).
# NOTE(review): presumably this axis broadcasts over attention heads - confirm.
print(mask[:, np.newaxis])